{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wJcYs_ERTnnI"
      },
      "source": [
        "##### Copyright 2021 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "HMUDt0CiUJk9"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "77z2OchJTk0l"
      },
      "source": [
        "# Migration examples: Early Stopping\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/guide/migrate/early_stopping\">\n",
        "    <img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />\n",
        "    View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/guide/migrate/early_stopping.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />\n",
        "    Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/guide/migrate/early_stopping.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />\n",
        "    View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/guide/migrate/early_stopping.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "meUTrR4I6m1C"
      },
      "source": [
        "This notebook demonstrates how to set up early stopping for model training in TF2 so that it matches the early stopping strategy you used in TF1 with the `tf.estimator` API.\n",
        "\n",
        "There are three ways to implement early stopping in TF2:\n",
        "- Use `tf.keras.callbacks.EarlyStopping` for `keras.Model`\n",
        "- Implement a custom callback for `keras.Model`\n",
        "- Implement the early stopping strategy manually if you [write your own training loop](https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch)\n",
        "\n",
        "If the TF1 code that uses an estimator adopts an early stopping strategy that monitors a certain quantity, such as the validation loss, you can use the `tf.keras.callbacks.EarlyStopping` callback with `keras.Model.fit`. If the TF1 code adopts another early stopping strategy, such as limiting the training time to a threshold, you can implement it manually with a custom callback or a custom Keras training loop."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YdZSoIXEbhg-"
      },
      "source": [
        "## Setup\n",
        "\n",
        "First, you need to define a couple of necessary imports."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iE0vSfMXumKI"
      },
      "outputs": [],
      "source": [
        "import time\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import tensorflow.compat.v1 as tf1\n",
        "import tensorflow_datasets as tfds"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4uXff1BEssdE"
      },
      "source": [
        "### TF1: Estimator.train/evaluate"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JaHhhhW5o8lL"
      },
      "source": [
        "Prepare the dataset and model for `tf.estimator.Estimator`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lqe9obf7suIj"
      },
      "outputs": [],
      "source": [
        "def normalize_img(image, label):\n",
        "  \"\"\"Casts `image` to float32 and divides it by 255; `label` is returned as is.\"\"\"\n",
        "  scaled_image = tf.cast(image, tf.float32) / 255.\n",
        "  return scaled_image, label\n",
        "\n",
        "def _input_fn():\n",
        "  \"\"\"Builds the Estimator training input from the MNIST `train` split.\n",
        "\n",
        "  Each example goes through `normalize_img`; the result is batched by 128\n",
        "  and the batched dataset is repeated 100 times.\n",
        "  \"\"\"\n",
        "  mnist_train = tfds.load(\n",
        "      name='mnist', split='train', shuffle_files=True, as_supervised=True)\n",
        "  return (\n",
        "      mnist_train\n",
        "      .map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
        "      .batch(128)\n",
        "      .repeat(100))\n",
        "\n",
        "def _eval_input_fn():\n",
        "  \"\"\"Builds the Estimator evaluation input from the MNIST `test` split.\n",
        "\n",
        "  Each example goes through `normalize_img`; the result is batched by 128\n",
        "  and is not repeated.\n",
        "  \"\"\"\n",
        "  mnist_test = tfds.load(\n",
        "      name='mnist', split='test', shuffle_files=True, as_supervised=True)\n",
        "  return (\n",
        "      mnist_test\n",
        "      .map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
        "      .batch(128))\n",
        "\n",
        "def _model_fn(features, labels, mode):\n",
        "  \"\"\"Estimator `model_fn`: Flatten -> Dense(128, relu) -> Dense(10).\n",
        "\n",
        "  Args:\n",
        "    features: Input batch produced by the input function.\n",
        "    labels: Class labels for the batch, fed to the sparse softmax loss.\n",
        "    mode: Mode value handed in by the Estimator; forwarded to the spec.\n",
        "\n",
        "  Returns:\n",
        "    An `EstimatorSpec` holding the cross-entropy loss and the Adagrad\n",
        "    training op.\n",
        "  \"\"\"\n",
        "  flat_inputs = tf1.layers.Flatten()(features)\n",
        "  hidden = tf1.layers.Dense(128, 'relu')(flat_inputs)\n",
        "  logits = tf1.layers.Dense(10)(hidden)\n",
        "\n",
        "  loss = tf1.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)\n",
        "  train_op = tf1.train.AdagradOptimizer(0.005).minimize(\n",
        "      loss, global_step=tf1.train.get_global_step())\n",
        "  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hC_AY7KwqD0p"
      },
      "source": [
        "For `tf.estimator.Estimator` in TF1, early stopping was configured by setting up the early stopping hook. The `should_stop_fn` parameter of `tf1.estimator.experimental.make_early_stopping_hook` accepts a function without any argument. You can define different early stopping strategies by implementing a custom `should_stop_fn`.\n",
        "\n",
        "In the example below, the early stopping strategy limits the training time to 20 seconds."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HsOpjW5plH9Q"
      },
      "outputs": [],
      "source": [
        "estimator = tf1.estimator.Estimator(model_fn=_model_fn)\n",
        "\n",
        "max_train_seconds = 20\n",
        "start_time = time.time()\n",
        "\n",
        "def should_stop_fn():\n",
        "  \"\"\"Returns True once more than `max_train_seconds` have passed since `start_time`.\"\"\"\n",
        "  elapsed_seconds = time.time() - start_time\n",
        "  return elapsed_seconds > max_train_seconds\n",
        "\n",
        "# The hook calls `should_stop_fn` on a time interval (every second) rather\n",
        "# than on a step interval.\n",
        "early_stopping_hook = tf1.estimator.experimental.make_early_stopping_hook(\n",
        "    estimator=estimator,\n",
        "    should_stop_fn=should_stop_fn,\n",
        "    run_every_secs=1,\n",
        "    run_every_steps=None)\n",
        "\n",
        "tf1.estimator.train_and_evaluate(\n",
        "    estimator,\n",
        "    tf1.estimator.TrainSpec(input_fn=_input_fn, hooks=[early_stopping_hook]),\n",
        "    tf1.estimator.EvalSpec(input_fn=_eval_input_fn))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KEmzBjfnsxwT"
      },
      "source": [
        "### TF2: Keras training API"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GKwxnkIksPFW"
      },
      "source": [
        "Prepare the dataset and Keras model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "atVciNgPs0fw"
      },
      "outputs": [],
      "source": [
        "# Load both MNIST splits as (image, label) pairs, along with dataset metadata.\n",
        "(ds_train, ds_test), ds_info = tfds.load(\n",
        "    'mnist',\n",
        "    split=['train', 'test'],\n",
        "    shuffle_files=True,\n",
        "    as_supervised=True,\n",
        "    with_info=True,\n",
        ")\n",
        "\n",
        "# Apply `normalize_img` to every example and batch by 128.\n",
        "ds_train = ds_train.map(\n",
        "    normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
        "ds_train = ds_train.batch(128)\n",
        "\n",
        "ds_test = ds_test.map(\n",
        "    normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
        "ds_test = ds_test.batch(128)\n",
        "\n",
        "model = tf.keras.models.Sequential([\n",
        "  # TFDS MNIST images carry a trailing channel axis, so the declared input\n",
        "  # shape is (28, 28, 1) to match the batches passed to `fit`.\n",
        "  tf.keras.layers.Flatten(input_shape=(28, 28, 1)),\n",
        "  tf.keras.layers.Dense(128, activation='relu'),\n",
        "  # No softmax here: the model outputs logits, hence `from_logits=True` below.\n",
        "  tf.keras.layers.Dense(10)\n",
        "])\n",
        "\n",
        "model.compile(\n",
        "    optimizer=tf.keras.optimizers.Adam(0.005),\n",
        "    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "559Goxp3tOMl"
      },
      "source": [
        "For `tf.keras.Model` in TF2, early stopping is configured by passing the callback, `tf.keras.callbacks.EarlyStopping`, to the `callbacks` parameter of `Model.fit`. The early stopping callback monitors a user-specified quantity to determine if it should end the training early. Please check the [API docs](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/EarlyStopping) for more details."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Kip65sYBlKiu"
      },
      "outputs": [],
      "source": [
        "# Watch the training loss and allow 3 epochs without improvement.\n",
        "callback = tf.keras.callbacks.EarlyStopping(patience=3, monitor='loss')\n",
        "# Only around 25 epochs are run during training, instead of 100\n",
        "history = model.fit(\n",
        "    ds_train,\n",
        "    validation_data=ds_test,\n",
        "    epochs=100,\n",
        "    callbacks=[callback],\n",
        ")\n",
        "# One loss entry is recorded per completed epoch.\n",
        "len(history.history['loss'])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wCwZ4BA8jaHY"
      },
      "source": [
        "You can also implement a custom early stopping callback. The training process will be stopped once `self.model.stop_training` is set to `True`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Hns1fmwtjCg2"
      },
      "outputs": [],
      "source": [
        "class LimitTrainingTime(tf.keras.callbacks.Callback):\n",
        "  \"\"\"Keras callback that stops training once a wall-clock budget is used up.\n",
        "\n",
        "  Args:\n",
        "    max_time_s: Maximum training time in seconds, measured from the moment\n",
        "      `on_train_begin` is called.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, max_time_s):\n",
        "    super().__init__()\n",
        "    self.max_time_s = max_time_s\n",
        "    self.start_time = None\n",
        "\n",
        "  def on_train_begin(self, logs=None):\n",
        "    # `logs=None` mirrors the base `Callback` signature.\n",
        "    # The clock starts when training begins, not when the callback is built.\n",
        "    self.start_time = time.time()\n",
        "\n",
        "  def on_train_batch_end(self, batch, logs=None):\n",
        "    # Checked after every batch rather than once per epoch.\n",
        "    elapsed_s = time.time() - self.start_time\n",
        "    if elapsed_s > self.max_time_s:\n",
        "      self.model.stop_training = True"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "s5mIzDOAkUKA"
      },
      "outputs": [],
      "source": [
        "# Limit the training time to 30 seconds\n",
        "callback = LimitTrainingTime(30)\n",
        "history = model.fit(\n",
        "    ds_train,\n",
        "    epochs=100,\n",
        "    validation_data=ds_test,\n",
        "    callbacks=[callback]\n",
        ")\n",
        "len(history.history['loss'])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kro_lKyEu60-"
      },
      "source": [
        "## TF2: Keras training API with Custom Training Loop"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g5LU0lebvuIk"
      },
      "source": [
        "Prepare the model, optimizer, and metrics."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oTGxr0PwAiQ4"
      },
      "outputs": [],
      "source": [
        "model = tf.keras.models.Sequential([\n",
        "  # TFDS MNIST images carry a trailing channel axis, so the declared input\n",
        "  # shape is (28, 28, 1) to match the batches fed to the training loop.\n",
        "  tf.keras.layers.Flatten(input_shape=(28, 28, 1)),\n",
        "  tf.keras.layers.Dense(128, activation='relu'),\n",
        "  # No softmax here: the model outputs logits.\n",
        "  tf.keras.layers.Dense(10)\n",
        "])\n",
        "optimizer = tf.keras.optimizers.Adam(0.005)\n",
        "loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
        "\n",
        "# The cross-entropy metrics receive the model's raw logits, so they need\n",
        "# `from_logits=True` just like `loss_fn`; with the default (`False`) they\n",
        "# would treat the logits as probabilities and report a wrong loss.\n",
        "train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()\n",
        "train_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy(\n",
        "    from_logits=True)\n",
        "val_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()\n",
        "val_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy(\n",
        "    from_logits=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zecsnqRxvy0Q"
      },
      "source": [
        "Define the train/test step. Use the `@tf.function` decorator for speedup by compiling the function into a static graph. More details on `tf.function` are available [here](https://www.tensorflow.org/guide/function)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "s3w_55n0Ah7L"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def train_step(x, y):\n",
        "  \"\"\"Runs one optimization step on a batch and returns the batch loss.\n",
        "\n",
        "  Also updates `train_acc_metric` and `train_loss_metric` with the batch's\n",
        "  labels and logits.\n",
        "  \"\"\"\n",
        "  with tf.GradientTape() as tape:\n",
        "    logits = model(x, training=True)\n",
        "    batch_loss = loss_fn(y, logits)\n",
        "  gradients = tape.gradient(batch_loss, model.trainable_weights)\n",
        "  optimizer.apply_gradients(zip(gradients, model.trainable_weights))\n",
        "\n",
        "  train_acc_metric.update_state(y, logits)\n",
        "  train_loss_metric.update_state(y, logits)\n",
        "  return batch_loss\n",
        "\n",
        "@tf.function\n",
        "def test_step(x, y):\n",
        "  \"\"\"Evaluates one batch in inference mode and updates the validation metrics.\"\"\"\n",
        "  val_logits = model(x, training=False)\n",
        "  val_acc_metric.update_state(y, val_logits)\n",
        "  val_loss_metric.update_state(y, val_logits)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-ZKS9ePGwd9r"
      },
      "source": [
        "In the training loop, you can implement any early stopping strategy manually. The example below is for monitoring if validation loss keeps decreasing every epoch."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iZOzHqqSAkpK"
      },
      "outputs": [],
      "source": [
        "epochs = 100\n",
        "patience = 5\n",
        "# Number of consecutive epochs without an improvement in `val_loss`.\n",
        "wait = 0\n",
        "# Lowest validation loss seen so far. Starting at +inf makes the first epoch\n",
        "# always count as an improvement.\n",
        "best = float('inf')\n",
        "\n",
        "for epoch in range(epochs):\n",
        "  print(\"\\nStart of epoch %d\" % (epoch,))\n",
        "  # Named `epoch_start_time` so the global `start_time` is not overwritten.\n",
        "  epoch_start_time = time.time()\n",
        "\n",
        "  for step, (x_batch_train, y_batch_train) in enumerate(ds_train):\n",
        "    loss_value = train_step(x_batch_train, y_batch_train)\n",
        "    if step % 200 == 0:\n",
        "      print(\"Training loss at step %d: %.4f\" % (step, loss_value.numpy()))\n",
        "      print(\"Seen so far: %s samples\" % ((step + 1) * 128))\n",
        "  # Read the epoch-level training metrics, then reset them for the next epoch.\n",
        "  train_acc = train_acc_metric.result()\n",
        "  train_loss = train_loss_metric.result()\n",
        "  train_acc_metric.reset_states()\n",
        "  train_loss_metric.reset_states()\n",
        "  print(\"Training acc over epoch: %.4f\" % (train_acc.numpy()))\n",
        "\n",
        "  for x_batch_val, y_batch_val in ds_test:\n",
        "    test_step(x_batch_val, y_batch_val)\n",
        "  val_acc = val_acc_metric.result()\n",
        "  val_loss = val_loss_metric.result()\n",
        "  val_acc_metric.reset_states()\n",
        "  val_loss_metric.reset_states()\n",
        "  print(\"Validation acc: %.4f\" % (float(val_acc),))\n",
        "  print(\"Time taken: %.2fs\" % (time.time() - epoch_start_time))\n",
        "\n",
        "  # The early stopping strategy: stop the training if `val_loss` does not\n",
        "  # decrease for a certain number of epochs.\n",
        "  wait += 1\n",
        "  if val_loss < best:\n",
        "    # A new lowest validation loss: record it and restart the patience count.\n",
        "    best = val_loss\n",
        "    wait = 0\n",
        "  if wait >= patience:\n",
        "    break"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "early_stopping.ipynb",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
